from datasets import load_dataset
import json

def _read_jsonl(path):
    """Read a JSON Lines file and return its parsed records as a list.

    Blank (whitespace-only) lines are skipped so that a stray empty line,
    e.g. an extra trailing newline, does not make ``json.loads`` fail.
    """
    with open(path, 'r', encoding='utf-8') as file:
        return [json.loads(line) for line in file if line.strip()]


def get_dataset(dataset="gsm8k", load_from_local=True, data_dir="data"):
    """Load question/answer splits for a math-reasoning dataset.

    Args:
        dataset: Dataset name. For remote loading only ``"gsm8k"`` is
            supported. For local loading the name must start with one of
            ``gsm8k``, ``prm800k``, ``math`` or ``mgsm``.
        load_from_local: If truthy, read
            ``{data_dir}/{dataset}/{dataset}_train.jsonl`` and
            ``{data_dir}/{dataset}/{dataset}_test.jsonl``. Otherwise
            download ``openai/gsm8k`` (config ``"main"``) via
            ``datasets.load_dataset``.
        data_dir: Root directory holding the local dataset folders
            (defaults to ``"data"``, relative to the working directory).

    Returns:
        A 4-tuple ``(train_questions, train_answers, test_questions,
        test_answers)``. Local loading returns plain lists. Remote loading
        returns the ``"question"``/``"answer"`` columns of the Hugging Face
        splits.

    Raises:
        ValueError: If the dataset name is not supported for the chosen
            source, or a local JSONL record is not a JSON object.
    """
    if not load_from_local:
        # Only GSM8K is wired up for remote download.
        if dataset != "gsm8k":
            raise ValueError("Invalid dataset.")
        ds_train = load_dataset("openai/gsm8k", "main", split="train")
        ds_test = load_dataset("openai/gsm8k", "main", split="test")
        return ds_train["question"], ds_train["answer"], ds_test["question"], ds_test["answer"]

    # Fail fast on unknown names. Previously this path fell through and
    # crashed with UnboundLocalError on the undefined ds_train/ds_test.
    if not dataset.startswith(("gsm8k", "prm800k", "math", "mgsm")):
        raise ValueError("Invalid dataset.")

    ds_train = _read_jsonl(f"{data_dir}/{dataset}/{dataset}_train.jsonl")
    ds_test = _read_jsonl(f"{data_dir}/{dataset}/{dataset}_test.jsonl")

    # Each line must decode to an object with "question"/"answer" keys.
    if not all(isinstance(item, dict) for item in ds_train + ds_test):
        raise ValueError("JSON file content is not a list of dictionaries")

    return (
        [data["question"] for data in ds_train],
        [data["answer"] for data in ds_train],
        [data["question"] for data in ds_test],
        [data["answer"] for data in ds_test],
    )
